package com.framework.components.ratelimiter;

import com.framework.components.redis.RedisHelper;
import redis.clients.jedis.Jedis;
import redis.clients.jedis.Response;
import redis.clients.jedis.TransactionBlock;
import redis.clients.jedis.exceptions.JedisException;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.List;
import java.util.Map;
import java.util.logging.Level;
import java.util.logging.Logger;

/**
 * Redis-backed rate limiter keyed by identity, window length and capacity.
 *
 * <p>{@link #access(String, long, long)} delegates the decision to the Lua script {@code rateLimiter.lua}. The script
 * is read from the classpath (relative to this class's package) and registered with Redis via SCRIPT LOAD when the
 * singleton is created. {@link #multi(String, long, long)} performs a token-bucket refill computation on the client
 * and writes the new state back in a WATCH-guarded MULTI/EXEC transaction.
 *
 * <p>Originally created by yangliu on 2017/8/31.
 */
public class LuaRateLimiterUtil {
    private static final Logger LOG = Logger.getLogger(LuaRateLimiterUtil.class.getName());

    /** Classpath resource, resolved relative to this class's package, that holds the limiter script. */
    private static final String SCRIPT_RESOURCE = "rateLimiter.lua";
    private static final String KEY_PREFIX = "rate:limiter:";
    private static final String FIELD_LAST_REFILL_TIME = "lastRefillTime";
    private static final String FIELD_TOKENS_REMAINING = "tokensRemaining";

    /** Lazily created singleton; volatile so the unsynchronized first read in getInstance() is safe. */
    private static volatile LuaRateLimiterUtil instance = null;

    /** Script source, kept so it can be re-registered if Redis drops its script cache. */
    private final String luaScript;
    /** SHA1 returned by SCRIPT LOAD; Redis derives it from the script text, so re-loading yields the same value. */
    private final String scriptSha1;

    private LuaRateLimiterUtil(String luaScript, String scriptSha1) {
        this.luaScript = luaScript;
        this.scriptSha1 = scriptSha1;
    }

    /**
     * Returns the shared limiter, creating it and registering the Lua script with Redis on first use.
     *
     * <p>The singleton is only published once fully initialized. If creation fails, nothing is cached and the next
     * call tries again. Runtime exceptions raised while obtaining a connection or running SCRIPT LOAD propagate to
     * the caller.
     *
     * @return the singleton, or {@code null} if {@code rateLimiter.lua} could not be read
     */
    public static LuaRateLimiterUtil getInstance() {
        LuaRateLimiterUtil result = instance;
        if (result == null) {
            synchronized (LuaRateLimiterUtil.class) {
                result = instance;
                if (result == null) {
                    result = create();
                    instance = result;
                }
            }
        }
        return result;
    }

    /** Reads the script and registers it with Redis; returns null (after logging) if the script cannot be read. */
    private static LuaRateLimiterUtil create() {
        String script;
        try {
            script = readResource(SCRIPT_RESOURCE);
        } catch (IOException e) {
            LOG.log(Level.SEVERE, "Unable to read Lua script " + SCRIPT_RESOURCE, e);
            return null;
        }
        Jedis jedis = RedisHelper.getInstance().getJedisWithOutShare();
        try {
            return new LuaRateLimiterUtil(script, jedis.scriptLoad(script));
        } finally {
            RedisHelper.getInstance().returnResource(jedis);
        }
    }

    /**
     * Applies one token-bucket request for {@code identity} using client-side logic.
     *
     * <p>Replies to commands queued inside MULTI cannot be read before EXEC. The bucket hash is therefore read with
     * HGETALL before the transaction starts, the new state is computed locally, and only the HSET writes are queued.
     * The key is WATCHed before the read, so a concurrent change makes EXEC abort instead of silently overwriting the
     * other writer. The returned reply holds only the HSET results; it does not itself say whether a token was
     * available.
     *
     * @param identity        caller identity, used as the last segment of the Redis key
     * @param intervalInMills window in milliseconds after which the bucket is fully refilled
     * @param capacity        maximum number of tokens; must be positive
     * @return the EXEC reply list, or {@code null}/empty (depending on the Jedis version) if the transaction was
     *         aborted because the key changed concurrently
     * @throws IllegalArgumentException if {@code capacity} is not positive
     */
    public final List<Object> multi(final String identity, final long intervalInMills, final long capacity) {
        if (capacity <= 0) {
            throw new IllegalArgumentException("capacity must be positive: " + capacity);
        }
        final String key = buildKey(identity, intervalInMills, capacity);
        Jedis jedis = RedisHelper.getInstance().getJedisWithOutShare();
        try {
            // WATCH before reading so EXEC aborts if another client modifies the bucket in between.
            jedis.watch(key);
            final long[] next;
            try {
                next = nextBucketState(jedis.hgetAll(key), System.currentTimeMillis(), intervalInMills, capacity);
            } catch (RuntimeException e) {
                // Avoid handing a connection with a pending WATCH back to the pool.
                unwatchQuietly(jedis);
                throw e;
            }
            return jedis.multi(new TransactionBlock() {
                @Override
                public void execute() throws JedisException {
                    if (next[0] >= 0) {
                        super.hset(key, FIELD_LAST_REFILL_TIME, String.valueOf(next[0]));
                    }
                    super.hset(key, FIELD_TOKENS_REMAINING, String.valueOf(next[1]));
                }
            });
        } finally {
            RedisHelper.getInstance().returnResource(jedis);
        }
    }

    /**
     * Computes the bucket state to write after consuming one token.
     *
     * @param stored          current hash contents (empty if the key does not exist)
     * @param now             current time in milliseconds
     * @param intervalInMills full-refill window in milliseconds
     * @param capacity        maximum tokens; caller guarantees it is positive
     * @return {@code {newLastRefillTime, tokensRemaining}}, where {@code newLastRefillTime} is {@code -1} when the
     *         stored refill time should be left unchanged
     * @throws NumberFormatException if a stored field is not a valid long
     */
    private static long[] nextBucketState(Map<String, String> stored, long now, long intervalInMills, long capacity) {
        // Integer millis per token; clamped to 1 so capacity > intervalInMills cannot cause a division by zero.
        long intervalPerPermit = Math.max(1L, intervalInMills / capacity);
        String storedRefill = stored.get(FIELD_LAST_REFILL_TIME);
        String storedTokens = stored.get(FIELD_TOKENS_REMAINING);

        long newRefillTime = -1L;
        long currentTokens;
        if (storedRefill == null || storedTokens == null) {
            // No (or incomplete) state yet: start from a full bucket.
            currentTokens = capacity;
            newRefillTime = now;
        } else {
            long lastRefillTime = Long.parseLong(storedRefill);
            long tokensRemaining = Long.parseLong(storedTokens);
            if (now > lastRefillTime) {
                long sinceLast = now - lastRefillTime;
                if (sinceLast > intervalInMills) {
                    currentTokens = capacity;
                    newRefillTime = now;
                } else {
                    long granted = sinceLast / intervalPerPermit;
                    if (granted > 0) {
                        // Keep the unused part of the current permit interval so it counts toward the next token.
                        newRefillTime = now - sinceLast % intervalPerPermit;
                    }
                    currentTokens = Math.min(granted + tokensRemaining, capacity);
                }
            } else {
                // Clock did not advance (or moved backwards): no refill.
                currentTokens = tokensRemaining;
            }
        }
        // Consume one token for this request without going below zero.
        return new long[] {newRefillTime, Math.max(0L, currentTokens - 1)};
    }

    /**
     * Asks the Lua script whether {@code identity} may proceed.
     *
     * <p>Fails closed. Any error talking to Redis, and any reply other than the integer 1, yields {@code false};
     * errors are logged. If Redis reports NOSCRIPT (its script cache was flushed, e.g. after a restart), the script
     * is re-registered and the call is retried once.
     *
     * @param identity        caller identity, used as the last segment of the Redis key
     * @param intervalInMills window length in milliseconds
     * @param capacity        number of permits per window
     * @return {@code true} only if the script replied with the integer 1
     */
    public final boolean access(final String identity, final long intervalInMills, final long capacity) {
        String key = buildKey(identity, intervalInMills, capacity);
        Jedis jedis = RedisHelper.getInstance().getJedisWithOutShare();
        try {
            return Long.valueOf(1L).equals(evalRateLimit(jedis, key, intervalInMills, capacity));
        } catch (Exception e) {
            LOG.log(Level.WARNING, "Rate limiter check failed for key " + key + "; denying access", e);
            return false;
        } finally {
            RedisHelper.getInstance().returnResource(jedis);
        }
    }

    /** Runs the limiter script via EVALSHA, re-registering it once if Redis no longer has it cached. */
    private Object evalRateLimit(Jedis jedis, String key, long intervalInMills, long capacity) {
        // KEYS[1] = key. ARGV = per-permit interval (double rather than float to avoid precision loss on long
        // windows), current time in millis, capacity, window in millis.
        // NOTE(review): argument meaning is inferred from the values passed here; confirm against rateLimiter.lua.
        String[] params = {
                key,
                String.valueOf(intervalInMills / (double) capacity),
                String.valueOf(System.currentTimeMillis()),
                String.valueOf(capacity),
                String.valueOf(intervalInMills)
        };
        try {
            return jedis.evalsha(scriptSha1, 1, params);
        } catch (JedisException e) {
            if (!isNoScriptError(e)) {
                throw e;
            }
            // SCRIPT LOAD of the same text yields the same SHA1, so scriptSha1 becomes valid again.
            jedis.scriptLoad(luaScript);
            return jedis.evalsha(scriptSha1, 1, params);
        }
    }

    /** True if Redis rejected EVALSHA because the script is not in its cache. */
    private static boolean isNoScriptError(JedisException e) {
        String message = e.getMessage();
        return message != null && message.startsWith("NOSCRIPT");
    }

    /** Builds the Redis key shared by {@link #access} and {@link #multi}. */
    private static String buildKey(String identity, long intervalInMills, long capacity) {
        return KEY_PREFIX + intervalInMills + ":" + capacity + ":" + identity;
    }

    /** Best-effort UNWATCH; failures are logged at FINE and otherwise ignored because the caller is already failing. */
    private static void unwatchQuietly(Jedis jedis) {
        try {
            jedis.unwatch();
        } catch (RuntimeException e) {
            LOG.log(Level.FINE, "UNWATCH failed", e);
        }
    }

    /**
     * Reads a file from the file system as UTF-8 text.
     *
     * @param luaScriptPath file-system path of the script
     * @return the file contents
     * @throws IOException if the file cannot be read
     */
    static String read(String luaScriptPath) throws IOException {
        Path path = Paths.get(luaScriptPath);
        return new String(Files.readAllBytes(path), StandardCharsets.UTF_8);
    }

    /**
     * Reads a classpath resource, resolved relative to this class's package, as UTF-8 text. Unlike converting a
     * resource URL to a file path, this also works inside JARs and for paths that need URL escaping.
     */
    private static String readResource(String name) throws IOException {
        try (InputStream in = LuaRateLimiterUtil.class.getResourceAsStream(name)) {
            if (in == null) {
                throw new IOException("Classpath resource not found: " + name);
            }
            ByteArrayOutputStream buffer = new ByteArrayOutputStream();
            byte[] chunk = new byte[4096];
            int read;
            while ((read = in.read(chunk)) != -1) {
                buffer.write(chunk, 0, read);
            }
            return new String(buffer.toByteArray(), StandardCharsets.UTF_8);
        }
    }

    /** Manual smoke test: runs one {@link #multi} request against the Redis connection provided by RedisHelper. */
    public static void main(String[] args) throws IOException, InterruptedException {
        System.out.println(LuaRateLimiterUtil.getInstance().multi("test", 6000L, 3L));
    }
}
